package directlinks

import (
	"context"
	"errors"
	"fmt"
	"io"
	"net/http"
	"path/filepath"
	"sync/atomic"

	"github.com/charmbracelet/log"
	"github.com/duke-git/lancet/v2/retry"
	"github.com/krau/SaveAny-Bot/common/utils/fsutil"
	"github.com/krau/SaveAny-Bot/common/utils/ioutil"
	"github.com/krau/SaveAny-Bot/config"
	"github.com/krau/SaveAny-Bot/pkg/enums/ctxkey"
	"golang.org/x/sync/errgroup"
)

func (t *Task) Execute(ctx context.Context) error {
	logger := log.FromContext(ctx)
	logger.Infof("Starting directlinks task %s", t.ID)
	if t.Progress != nil {
		t.Progress.OnStart(ctx, t)
	}
	// head all links to get file info
	eg, gctx := errgroup.WithContext(ctx)
	eg.SetLimit(config.C().Workers)
	fetchedTotalBytes := atomic.Int64{}
	for _, file := range t.files {
		eg.Go(func() error {
			req, err := http.NewRequestWithContext(ctx, http.MethodHead, file.URL, nil)
			if err != nil {
				return fmt.Errorf("failed to create HEAD request for %s: %w", file.URL, err)
			}
			resp, err := t.client.Do(req)
			if err != nil {
				return fmt.Errorf("failed to HEAD %s: %w", file.URL, err)
			}
			defer resp.Body.Close()
			if resp.StatusCode < 200 || resp.StatusCode >= 300 {
				return fmt.Errorf("HEAD %s returned status %d", file.URL, resp.StatusCode)
			}
			fetchedTotalBytes.Add(resp.ContentLength)
			file.Size = resp.ContentLength
			if name := resp.Header.Get("Content-Disposition"); name != "" {
				// Set file name
				filename := parseFilename(name)
				file.Name = filename
			}

			return nil
		})
	}
	err := eg.Wait()
	if err != nil {
		logger.Errorf("Error during HEAD requests: %v", err)
		if t.Progress != nil {
			t.Progress.OnDone(ctx, t, err)
		}
		return err
	}
	t.totalBytes = fetchedTotalBytes.Load()
	// start downloading
	eg, gctx = errgroup.WithContext(ctx)
	eg.SetLimit(config.C().Workers)
	for _, file := range t.files {
		eg.Go(func() error {
			t.processingMu.RLock()
			if _, ok := t.processing[file.URL]; ok {
				return fmt.Errorf("file %s is already being processed", file.URL)
			}
			t.processingMu.RUnlock()
			t.processingMu.Lock()
			t.processing[file.URL] = file
			t.processingMu.Unlock()
			defer func() {
				t.processingMu.Lock()
				delete(t.processing, file.URL)
				t.processingMu.Unlock()
			}()
			err := t.processLink(gctx, file)
			t.downloaded.Add(1)
			if errors.Is(err, context.Canceled) {
				logger.Debug("Link processing canceled")
				return err
			}
			if err != nil {
				logger.Errorf("Error processing link %s: %v", file.URL, err)
				return fmt.Errorf("failed to process link %s: %w", file.URL, err)
			}
			return nil
		})
	}
	err = eg.Wait()
	if err != nil {
		logger.Errorf("Error during directlinks task execution: %v", err)
	} else {
		logger.Infof("Directlinks task %s completed successfully", t.ID)
	}
	if t.Progress != nil {
		t.Progress.OnDone(ctx, t, err)
	}
	return err
}

func (t *Task) processLink(ctx context.Context, file *File) error {
	logger := log.FromContext(ctx)
	err := retry.Retry(func() error {
		req, err := http.NewRequestWithContext(ctx, http.MethodGet, file.URL, nil)
		if err != nil {
			return fmt.Errorf("failed to create GET request for %s: %w", file.URL, err)
		}
		resp, err := t.client.Do(req)
		if err != nil {
			return fmt.Errorf("failed to GET %s: %w", file.URL, err)
		}
		defer resp.Body.Close()
		if resp.StatusCode < 200 || resp.StatusCode >= 300 {
			return fmt.Errorf("GET %s returned status %d", file.URL, resp.StatusCode)
		}
		ctx = context.WithValue(ctx, ctxkey.ContentLength, file.Size)
		if t.stream {
			return t.Storage.Save(ctx, resp.Body, filepath.Join(t.StorPath, file.Name))
		}
		cacheFile, err := fsutil.CreateFile(filepath.Join(config.C().Temp.BasePath,
			fmt.Sprintf("direct_%s_%s", t.ID, file.Name)))
		if err != nil {
			return fmt.Errorf("failed to create temp file: %w", err)
		}
		defer func() {
			if err := cacheFile.CloseAndRemove(); err != nil {
				logger.Errorf("Failed to close and remove cache file: %v", err)
			}
		}()
		wr := ioutil.NewProgressWriter(cacheFile, func(n int) {
			t.downloadedBytes.Add(int64(n))
			if t.Progress != nil {
				t.Progress.OnProgress(ctx, t)
			}
		})

		copyResultCh := make(chan error, 1)
		go func() {
			_, err := io.Copy(wr, resp.Body)
			copyResultCh <- err
		}()
		select {
		case err := <-copyResultCh:
			if err != nil {
				return fmt.Errorf("failed to copy file %s to cache file: %w", file.URL, err)
			}
		case <-ctx.Done():
			return ctx.Err()
		}
		_, err = cacheFile.Seek(0, 0)
		if err != nil {
			return fmt.Errorf("failed to seek cache file for resource %s: %w", file.URL, err)
		}
		return t.Storage.Save(ctx, cacheFile, filepath.Join(t.StorPath, file.Name))
	}, retry.RetryTimes(uint(config.C().Retry)), retry.Context(ctx))
	if ctx.Err() != nil {
		return ctx.Err()
	}
	return err
}
